#!/usr/bin/lua
--[[
 *
 * Copyright (C) 2023 hanwckf <hanwckf@vip.qq.com>
 *
 * 	This program is free software; you can redistribute it and/or modify
 * 	it under the terms of the GNU General Public License as published by
 * 	The Free Software Foundation; either version 2 of the License, or
 * 	(at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
]]

local cjson = require "cjson"
local inspect = require "inspect"
local nixio = require "nixio"
local l1parser = require "l1dat_parser"
local datconf = require "datconf"
local utils = require "mtwifi_utils"
local defs = require "mtwifi_defs"

-- Device table loaded once at start-up from the path exported by l1dat_parser.
-- The rest of this file reads l1dat.devname_ridx[devname], whose entries are
-- used for main_ifname, ext_ifname, apcli_ifname, wds_ifname, mesh_ifname,
-- INDEX, mainidx and profile_path.
local l1dat = l1parser.load_l1_profile(l1parser.L1_DAT_PATH)

-- Split the string into its non-empty pieces.
-- sep: set of delimiter characters (default ":"). It is inserted unescaped
--      into a "[^...]" character class, so every character in it acts as a
--      delimiter and pattern-magic characters keep their class meaning.
-- Returns an array of pieces; consecutive delimiters never yield empty entries.
function string:split(sep)
    local delim = sep or ":"
    local pieces = {}
    for piece in self:gmatch(("([^%s]+)"):format(delim)) do
        pieces[#pieces + 1] = piece
    end
    return pieces
end

-- Read every setting from the profile at `path` through datconf.
-- Returns the table produced by getall(), or an empty table when
-- datconf.openfile() does not return a handle.
function load_profile(path)
    local cfgs = {}

    -- Declared local so the handle does not leak into the global environment.
    local cfgobj = datconf.openfile(path)
    if cfgobj then
        cfgs = cfgobj:getall()
        cfgobj:close()
    end

    return cfgs
end

-- Merge `cfgs` into the profile at `path`, close it with the flag `true`,
-- then run "sync".
-- Does nothing when `cfgs` is nil/false. When the profile cannot be opened
-- the failure is logged and nothing is written (previously this crashed on
-- a nil handle).
function save_profile(cfgs, path)
    if not cfgs then
        return
    end

    local datobj = datconf.openfile(path)
    if not datobj then
        nixio.syslog("err", "mtwifi-cfg: profile " .. tostring(path) .. " open failed")
        return
    end

    datobj:merge(cfgs)
    datobj:close(true)

    os.execute("sync")
end

-- Compare two flat tables key by key.
-- Values are compared through tostring(), so 1 and "1" count as equal.
-- Returns a table mapping each differing key to {value_in_cfg1, value_in_cfg2},
-- where a missing (or false) value is reported as "".
function diff_cfg(cfg1, cfg2)
    local changed = {}

    local function collect(keys_from)
        for key in pairs(keys_from) do
            local before, after = cfg1[key], cfg2[key]
            if tostring(before) ~= tostring(after) then
                changed[key] = {before or "", after or ""}
            end
        end
    end

    -- Walk both key sets so keys present on only one side are reported too.
    collect(cfg1)
    collect(cfg2)

    return changed
end

-- Report "up" or "down" for an interface from /sys/class/net/<ifname>/flags.
-- The interface counts as "up" when the flags value is odd (lowest bit set,
-- presumably IFF_UP); an unreadable or non-numeric value is treated as 0.
function vif_status(ifname)
    local cmd = "cat /sys/class/net/"..ifname.."/flags 2>/dev/null"
    local flags = tonumber(utils.read_pipe(cmd)) or 0
    return (flags % 2 == 1) and "up" or "down"
end

-- Return true when /sys/class/net/<ifname>/address, after utils.trim, is a
-- truthy value other than the all-zero MAC address; false otherwise.
function is_dbdc_inited(ifname)
    local mac = utils.trim(utils.read_pipe("cat /sys/class/net/"..ifname.."/address 2>/dev/null"))
    if not mac or mac == "00:00:00:00:00:00" then
        return false
    end
    return true
end

-- Bring an interface up with ifconfig and log the action.
-- A nil/false or empty `ifname` is ignored.
function ifup(ifname)
    if not ifname or #ifname == 0 then
        return
    end
    os.execute("ifconfig "..ifname.." up")
    nixio.syslog("info", "mtwifi-cfg: ifconfig up "..ifname)
end

-- Take an interface down with ifconfig and log the action.
-- A nil/false or empty `ifname` is ignored.
function ifdown(ifname)
    if not ifname or #ifname == 0 then
        return
    end
    os.execute("ifconfig "..ifname.." down")
    nixio.syslog("info", "mtwifi-cfg: ifconfig down "..ifname)
end

-- Convert a config value into the form stored in a profile:
-- booleans become 1/0, numbers and strings pass through unchanged.
-- Any other type yields no return value.
function cfg2dat(cfg)
    local kind = type(cfg)
    if kind == "boolean" then
        return cfg and 1 or 0
    end
    if kind == "number" or kind == "string" then
        return cfg
    end
end

-- Store a device-wide setting: dats[dat] = cfg2dat(cfg).
-- A nil `cfg` leaves `dats` untouched.
function set_dev_dat(dats, dat, cfg)
    if cfg == nil then
        return
    end
    dats[dat] = cfg2dat(cfg)
end

-- Store a per-interface setting by replacing slot `apidx` of dats[dat]
-- through utils.token_set, with the value converted by cfg2dat.
-- A nil `cfg` leaves `dats` untouched.
function set_dat(dats, apidx, dat, cfg)
    if cfg == nil then
        return
    end
    dats[dat] = utils.token_set(dats[dat], apidx, cfg2dat(cfg))
end

-- Store `cfg` unchanged under the key formed by `dat` followed by `apidx`
-- (e.g. "SSID" and 1 give "SSID1"). A nil `cfg` leaves `dats` untouched.
function set_idx_dat(dats, apidx, dat, cfg)
    if cfg == nil then
        return
    end
    local key = dat..tostring(apidx)
    dats[key] = cfg
end

-- Tell whether interface name `vif` belongs to device entry `dev`.
-- True when `vif` equals dev.main_ifname, or contains one of the ext, apcli,
-- wds or mesh interface prefixes followed by at least one digit. The patterns
-- are not anchored, so the match may occur anywhere inside `vif`.
function vif_in_dev(vif, dev)
    if vif == dev.main_ifname then
        return true
    end

    local function numbered(prefix)
        return string.match(vif, utils.esc(prefix).."[0-9]+") ~= nil
    end

    return numbered(dev.ext_ifname)
        or numbered(dev.apcli_ifname)
        or numbered(dev.wds_ifname)
        or numbered(dev.mesh_ifname)
end

-- Return true when `vif` contains dev.apcli_ifname followed by at least one
-- digit (unanchored match), false otherwise.
function vif_is_apcli(vif, dev)
    local pattern = utils.esc(dev.apcli_ifname).."[0-9]+"
    return string.match(vif, pattern) ~= nil
end

-- Return true when cfg.config.htmode is set and starts with "HE".
function is_ax_mode(cfg)
    local htmode = cfg.config.htmode
    if not htmode then
        return false
    end
    return string.sub(htmode, 1, 2) == "HE"
end

-- Count the entries of cfg.interfaces.
-- Returns three numbers: total entries, entries whose config.mode is "ap",
-- and entries whose config.mode is "sta".
function vif_count(cfg)
    local total, aps, stas = 0, 0, 0

    for _, vif in pairs(cfg.interfaces) do
        local mode = vif.config.mode
        total = total + 1
        if mode == "ap" then
            aps = aps + 1
        elseif mode == "sta" then
            stas = stas + 1
        end
    end

    return total, aps, stas
end

-- Call func(v) for every interface entry whose config.mode equals `mode`.
-- Entries are looked up by the zero-padded keys "00", "01", ... up to the
-- number of entries in cfg.interfaces, so they are visited in index order.
-- Does nothing when `cfg` or `func` is nil.
function mtwifi_cfg_foreach_vif(cfg, mode, func)
    if cfg == nil or func == nil then return end

    local vif_num = vif_count(cfg)
    for idx = 0, vif_num - 1 do
        local v = cfg.interfaces[string.format("%02d", idx)]
        -- The key sequence may have gaps; skip a missing entry instead of
        -- failing on a nil index.
        if v and v.config and v.config.mode and v.config.mode == mode then
            func(v)
        end
    end
end

-- Issue the two iwpriv settings used to start an apcli interface:
-- ApCliEnable=1 followed by ApCliAutoConnect=3.
function __apcli_auto_connect(ifname)
    local settings = {
        { "ApCliEnable", "1" },
        { "ApCliAutoConnect", "3" },
    }
    for _, kv in ipairs(settings) do
        __exec_iwpriv_cmd(ifname, kv[1], kv[2])
    end
end

-- Bring up the interfaces of device `devname`.
-- devname:      key into l1dat.devname_ridx; unknown names are ignored.
-- cfg:          decoded config (cfg.device, cfg.interfaces).
-- restore_vifs: optional array of interface names that were taken down
--               earlier and should be brought back.
-- is_dbdc:      when truthy, the "<INDEX>_1_1" device's main interface is
--               cycled up/down first if is_dbdc_inited() reports false, and
--               for a device other than cfg.device only `restore_vifs` are
--               brought back (AP-side interfaces first, then apcli ones).
-- Otherwise every enabled "ap" entry, then every enabled "sta" entry, of
-- cfg.interfaces that has an mtwifi_ifname is brought up.
function mtwifi_up(devname, cfg, restore_vifs, is_dbdc)
    nixio.syslog("info", "mtwifi-cfg: start "..devname)
    local dev = l1dat.devname_ridx[devname]
    if not dev then return end

    if is_dbdc then
        local main_dev_name = dev.INDEX .. "_1_1"
        local main_ifname = l1dat.devname_ridx[main_dev_name].main_ifname

        if not is_dbdc_inited(main_ifname) then
            nixio.syslog("info", "mtwifi-cfg: dbdc card init...")
            ifup(main_ifname)
            utils.sleep(1)
            ifdown(main_ifname)
        end

        if restore_vifs and devname ~= cfg.device then
            -- first pass: everything on this device that is not an apcli vif
            for _, name in ipairs(restore_vifs) do
                if vif_in_dev(name, dev) and (not vif_is_apcli(name, dev)) then
                    nixio.syslog("info", "mtwifi-cfg: restore ap vif: "..name)
                    ifup(name)
                end
            end
            -- second pass: apcli vifs, which also get auto-connect re-armed
            for _, name in ipairs(restore_vifs) do
                if vif_is_apcli(name, dev) then
                    nixio.syslog("info", "mtwifi-cfg: restore apcli vif: "..name)
                    ifup(name)
                    __apcli_auto_connect(name)
                end
            end
            return
        end
    end

    local function bring_up(entry)
        if entry.config.disabled or not entry.mtwifi_ifname then
            return
        end
        local name = entry.mtwifi_ifname
        nixio.syslog("info", "mtwifi-cfg: up vif: "..name)
        ifup(name)
    end

    mtwifi_cfg_foreach_vif(cfg, "ap", bring_up)
    mtwifi_cfg_foreach_vif(cfg, "sta", bring_up)
end

-- Take down every interface listed in /sys/class/net that is up and belongs
-- to device `devname`.
-- cfg: optional decoded config; when given and devname differs from
--      cfg.device, the names taken down are collected.
-- Returns the array of collected names (possibly empty), or nothing when
-- `devname` is not in l1dat.devname_ridx.
function mtwifi_down(devname, cfg)
    nixio.syslog("info", "mtwifi-cfg: stop "..devname)
    local dev = l1dat.devname_ridx[devname]
    if not dev then return end

    local stopped = {}
    local names = (utils.read_pipe("ls /sys/class/net")):split("\n")

    for _, name in ipairs(names) do
        if vif_status(name) == "up" and vif_in_dev(name, dev) then
            nixio.syslog("info", "mtwifi-cfg: down vif: "..name)
            ifdown(name)
            if cfg and devname ~= cfg.device then
                stopped[#stopped + 1] = name
            end
        end
    end

    return stopped
end

-- Unload and reload the Wi-Fi kernel modules.
-- Modules are removed in reverse load order, followed by a 2 second pause,
-- then loaded again in order, followed by a 1 second pause.
-- mt7915_mt_wifi, mtfwd and mt_whnat are intentionally not handled here
-- (they were commented out in the original sequence).
function mtwifi_reinstall()
    local modules = { "mt_wifi", "mtk_warp", "mtk_warp_proxy" }

    nixio.syslog("info", "mtwifi-cfg: unload mtwifi module...")
    for i = #modules, 1, -1 do
        os.execute("rmmod " .. modules[i])
    end

    utils.sleep(2)

    nixio.syslog("info", "mtwifi-cfg: reload mtwifi module...")
    for i = 1, #modules do
        os.execute("modprobe " .. modules[i])
    end

    utils.sleep(1)
end

-- Log and run "iwpriv <ifname> set <key>=<val>".
-- `val` is converted with tostring() before being placed in the command.
function __exec_iwpriv_cmd(ifname, key, val)
    local command = ("iwpriv %s set %s=%s"):format(ifname, key, tostring(val))
    nixio.syslog("info", "mtwifi-cfg: iwpriv cmd: "..command)
    os.execute(command)
end

-- Apply runtime (iwpriv) settings after the interfaces are up.
-- For each enabled "ap" entry with an mtwifi_ifname, every item of
-- defs.iwpriv_ap_cfgs is set using the entry's config value, or the item's
-- second element when the config value is nil/false.
-- For each enabled "sta" entry with an mtwifi_ifname, apcli auto-connect
-- is triggered.
function mtwifi_cfg_iwpriv_hook(cfg)
    -- Interface name of an enabled entry, or a falsy value otherwise.
    local function active_ifname(entry)
        return (not entry.config.disabled) and entry.mtwifi_ifname
    end

    mtwifi_cfg_foreach_vif(cfg, "ap", function(entry)
        local name = active_ifname(entry)
        if not name then return end
        for uci_key, spec in pairs(defs.iwpriv_ap_cfgs) do
            __exec_iwpriv_cmd(name, spec[1], entry.config[uci_key] or spec[2])
        end
    end)

    mtwifi_cfg_foreach_vif(cfg, "sta", function(entry)
        local name = active_ifname(entry)
        if not name then return end
        __apcli_auto_connect(name)
    end)
end

-- Apply chip-wide settings; only runs when cfg.config.dbdc_main is truthy.
-- For each item of defs.chip_cfgs (config key -> {profile key, default}) the
-- profile key in `dats` is reset to the default and then overwritten by the
-- config value when one is present.
-- The resulting values are then copied into, and saved to, the profile of
-- every other device in l1dat.devname_ridx sharing this device's INDEX and
-- mainidx.
function set_chip_cfg(cfg, dats)
    if not cfg.config.dbdc_main then return end

    for uci_key, spec in pairs(defs.chip_cfgs) do
        local dat_key = spec[1]

        -- reset to default, then let the current device's config override it
        dats[dat_key] = spec[2]
        set_dev_dat(dats, dat_key, cfg.config[uci_key])
    end

    -- propagate to the other devices on the same chip
    local current = l1dat.devname_ridx[cfg.device]
    for name, other in pairs(l1dat.devname_ridx) do
        local same_chip = name ~= cfg.device
            and other.INDEX == current.INDEX
            and other.mainidx == current.mainidx
        if same_chip then
            local other_dats = load_profile(other.profile_path)
            if other_dats then
                for _, spec in pairs(defs.chip_cfgs) do
                    set_dev_dat(other_dats, spec[1], dats[spec[1]])
                end
                save_profile(other_dats, other.profile_path)
            end
        end
    end
end

-- Apply a JSON-encoded device configuration.
-- argv: JSON text with `device`, `config` and `interfaces` members.
-- The function loads the device's profile, rewrites the radio, apcli and
-- per-interface settings from the config, saves the profile, then restarts
-- the affected interfaces (reloading the driver modules when a key listed in
-- defs.reinstall_cfgs changed) and finally applies the iwpriv settings.
-- Invalid input is logged through syslog and aborts without saving.
function mtwifi_cfg_setup(argv)
    -- cjson.decode raises on malformed input; report it instead of crashing.
    local ok, cfg = pcall(cjson.decode, argv)
    if not ok or type(cfg) ~= "table"
        or type(cfg.config) ~= "table" or type(cfg.interfaces) ~= "table" then
        nixio.syslog("err", "mtwifi-cfg: invalid config input!")
        return
    end

    utils.log2file("input = " .. inspect(cfg))

    local devname = cfg.device
    local dev = l1dat.devname_ridx[devname]

    if not dev then return end

    local profile = dev.profile_path

    local dats = load_profile(profile)
    if not dats then
        nixio.syslog("err", "mtwifi-cfg: profile " .. profile .. " open failed")
        return
    end
    -- second, untouched copy used to compute what changed
    local dats_orig = load_profile(profile)

    local vif_num, bssid_num, apcli_num = vif_count(cfg)

    if vif_num == 0 then
        nixio.syslog("err", "mtwifi-cfg: not valid vif found!")
        return
    end
    if bssid_num > defs.max_mbssid then
        nixio.syslog("err", "mtwifi-cfg: too many vifs!")
        return
    end
    if apcli_num > 1 then
        -- logged only: setup continues and uses the first "sta" entry that
        -- the apcli loop below encounters
        nixio.syslog("err", "mtwifi-cfg: only support one apcli vif!")
    end

    if bssid_num > 0 then
        dats.BssidNum = bssid_num
    else
        dats.BssidNum = 1
    end

    -- setup wireless mode
    if type(cfg.config.htmode) ~= type("") then
        nixio.syslog("err", "mtwifi-cfg: invalid htmode!")
        return
    end

    local WirelessMode

    if cfg.config.band == "2g" then
        WirelessMode = 9 -- PHY_11BGN_MIXED
        if is_ax_mode(cfg) then
            WirelessMode = 16 -- PHY_11AX_24G
        end
    elseif cfg.config.band == "5g" then
        WirelessMode = 15 -- PHY_11VHT_N_MIXED
        if is_ax_mode(cfg) then
            WirelessMode = 17 -- PHY_11AX_5G
        end
    elseif cfg.config.band == "6g" then
        if is_ax_mode(cfg) then
            WirelessMode = 18 -- PHY_11AX_6G
        end
    end

    if not WirelessMode then
        nixio.syslog("err", "mtwifi-cfg: invalid wireless mode!")
        return
    end

    -- setup bandwidth
    local bw = string.match(cfg.config.htmode, "%d+")
    if not bw then
        nixio.syslog("err", "mtwifi-cfg: invalid bandwidth!")
        return
    end

    -- default bw is 20MHz
    dats.HT_BW = 0
    dats.VHT_BW = 0

    if bw == "40" then
        dats.HT_BW = 1
        dats.VHT_BW = 0
        if cfg.config.noscan == "1" then
            dats.HT_BSSCoexistence = 0
        else
            dats.HT_BSSCoexistence = 1
        end
    elseif bw == "80" then
        dats.HT_BW = 1
        dats.VHT_BW = 1
    elseif bw == "160" then
        dats.HT_BW = 1
        dats.VHT_BW = 2
    end

    -- setup apcli
    dats.ApCliEnable = 0
    dats.ApCliSsid = ""
    dats.ApCliBssid = ""
    dats.ApCliAuthMode = ""
    dats.ApCliEncrypType = ""
    dats.ApCliWPAPSK = ""
    dats.ApCliPMFMFPC = 0
    dats.ApCliPMFMFPR = 0
    dats.ApCliPMFSHA256 = 0

    for _,v in pairs(cfg.interfaces) do
        if v.config.mode == "sta" then
            -- an encryption name missing from defs.enc2dat would otherwise
            -- fail on a nil index
            local enc = defs.enc2dat[v.config.encryption]
            if not enc then
                nixio.syslog("err", "mtwifi-cfg: invalid encryption!")
                return
            end

            if not v.config.disabled then
                dats.ApCliEnable = 1
            end
            dats.ApCliSsid = v.config.ssid or ""
            dats.ApCliBssid = v.config.bssid or ""
            dats.ApCliAuthMode = enc[1]
            dats.ApCliEncrypType = enc[2]
            dats.ApCliWPAPSK = v.config.key or ""

            if dats.ApCliAuthMode == "OWE" or dats.ApCliAuthMode == "WPA3PSK" then
                dats.ApCliPMFMFPC = 1
                dats.ApCliPMFMFPR = 1
                dats.ApCliPMFSHA256 = 0
            end
            break
        end
    end

    -- setup chip cfgs
    set_chip_cfg(cfg, dats)

    -- setup other dev cfgs
    if type(cfg.config.country) == type("") and #cfg.config.country == 2 then
        -- a country missing from defs.countryRegions would otherwise fail
        -- on a nil index
        local regions = defs.countryRegions[cfg.config.country]
        if not regions then
            nixio.syslog("err", "mtwifi-cfg: invalid country code!")
            return
        end

        dats.CountryCode = cfg.config.country
        if cfg.config.band == "2g" then
            dats.CountryRegion = regions[1]
        elseif cfg.config.band == "5g" or cfg.config.band == "6g" then
            dats.CountryRegionABand = regions[2]
        end
    end

    if cfg.config.channel == "auto" then
        dats.AutoChannelSelect = 3
        dats.Channel = 0
    else
        dats.AutoChannelSelect = 0
        dats.Channel = cfg.config.channel
    end

    if cfg.config.txpower and cfg.config.txpower < 100 then
        dats.PERCENTAGEenable = 1
        dats.TxPower = cfg.config.txpower
    else
        dats.PERCENTAGEenable = 0
        dats.TxPower = 100
    end

    if cfg.config.mu_beamformer then
        dats.ETxBfEnCond = 1
        if dats.ApCliEnable == 1 then
            dats.MUTxRxEnable = 3
        else
            dats.MUTxRxEnable = 1
        end
        dats.ITxBfEn = 0
    else
        dats.ETxBfEnCond = 0
        dats.MUTxRxEnable = 0
        dats.ITxBfEn = 0
    end

    if is_ax_mode(cfg) and cfg.config.twt then
        dats.TWTSupport = cfg.config.twt
    else
        dats.TWTSupport = 0
    end

    -- reset vif cfgs to default
    for k,v in pairs(defs.vif_cfgs) do
        dats[k] = ""
        for i = 1, bssid_num do
            dats[k] = utils.token_set(dats[k], i, v)
        end
    end

    for k,v in pairs(defs.vif_cfgs_idx) do
        for i = 1, defs.max_mbssid do
            dats[k..tostring(i)] = ""
        end
        for i = 1, bssid_num do
            dats[k..tostring(i)] = v
        end
    end

    for k,v in pairs(defs.vif_acl) do
        for i = 0, defs.max_mbssid-1 do
            dats[k..tostring(i)] = ""
        end
        for i = 0, bssid_num-1 do
            dats[k..tostring(i)] = v
        end
    end

    -- setup vif cfgs
    -- apidx is the 1-based slot of the AP in the profile; the ACL keys
    -- (AccessPolicy/AccessControlList) are 0-based, hence apidx-1 below
    local apidx = 1
    for idx = 0,vif_num-1 do
        -- declared local so it does not leak into the global environment;
        -- a missing "NN" key is skipped instead of failing on a nil index
        local v = cfg.interfaces[string.format("%02d", idx)]
        if v and v.config.mode == "ap" then
            local enc = defs.enc2dat[v.config.encryption]
            if not enc then
                nixio.syslog("err", "mtwifi-cfg: invalid encryption!")
                return
            end

            set_idx_dat(dats, apidx, "SSID", v.config.ssid)
            set_idx_dat(dats, apidx, "WPAPSK", v.config.key or "")
            set_dat(dats, apidx, "NoForwarding", v.config.isolate)
            set_dat(dats, apidx, "HideSSID", v.config.hidden)
            set_dat(dats, apidx, "WmmCapable", v.config.wmm)
            set_dat(dats, apidx, "RRMEnable", v.config.ieee80211k)
            set_dat(dats, apidx, "RekeyInterval", v.config.wpa_group_rekey)
            set_dat(dats, apidx, "MuMimoDlEnable", v.config.mumimo_dl)
            set_dat(dats, apidx, "MuMimoUlEnable", v.config.mumimo_ul)
            set_dat(dats, apidx, "MuOfdmaDlEnable", v.config.ofdma_dl)
            set_dat(dats, apidx, "MuOfdmaUlEnable", v.config.ofdma_ul)
            set_dat(dats, apidx, "HT_AMSDU", v.config.amsdu)
            set_dat(dats, apidx, "HT_AutoBA", v.config.autoba)
            set_dat(dats, apidx, "APSDCapable", v.config.uapsd)
            set_dat(dats, apidx, "RTSThreshold", v.config.rts)
            set_dat(dats, apidx, "FragThreshold",v.config.frag)
            set_dat(dats, apidx, "DtimPeriod",v.config.dtim_period)
            set_dat(dats, apidx, "WirelessMode", WirelessMode)

            if is_ax_mode(cfg) then
                set_dat(dats, apidx, "HT_BAWinSize", 256)
            else
                set_dat(dats, apidx, "HT_BAWinSize", 64)
            end

            if v.config.macfilter then
                if v.config.macfilter == "allow" then
                    set_idx_dat(dats, apidx-1, "AccessPolicy", 1)
                elseif v.config.macfilter == "deny" then
                    set_idx_dat(dats, apidx-1, "AccessPolicy", 2)
                end
            end

            if v.config.maclist then
                -- join at most defs.max_acl_entry addresses with ";"
                local maclist = ""
                for i, mac in ipairs(v.config.maclist) do
                    if i > defs.max_acl_entry then
                        break
                    end
                    if #maclist > 0 then
                        maclist = maclist .. ";" .. mac
                    else
                        maclist = mac
                    end
                end
                set_idx_dat(dats, apidx-1, "AccessControlList", maclist)
            end

            local authmode = enc[1]
            set_dat(dats, apidx, "AuthMode", authmode)
            set_dat(dats, apidx, "EncrypType", enc[2])

            if authmode == "OWE" or authmode == "WPA3PSK" then
                set_dat(dats, apidx, "PMFMFPC", 1)
                set_dat(dats, apidx, "PMFMFPR", 1)
                set_dat(dats, apidx, "PMFSHA256", 0)
            elseif authmode == "WPA2PSKWPA3PSK" then
                set_dat(dats, apidx, "PMFMFPC", 1)
                set_dat(dats, apidx, "PMFMFPR", 0)
                set_dat(dats, apidx, "PMFSHA256", 0)
            end

            if not (authmode == "OPEN" or authmode == "OWE") then
                set_dat(dats, apidx, "RekeyMethod", "TIME")
            end

            apidx = apidx + 1
        end
    end

    local reinstall_wifidrv = false
    local cfg_diff = diff_cfg(dats_orig, dats)
    utils.log2file("diff = " .. inspect(cfg_diff))

    save_profile(dats, profile)

    -- a change to any key in defs.reinstall_cfgs needs a driver reload,
    -- which is only possible when mt_wifi is a loadable module
    for _,v in pairs(defs.reinstall_cfgs) do
        if cfg_diff[v] ~= nil then
            if utils.exists("/sys/module/mt_wifi") == false then
                nixio.syslog("err", "mtwifi-cfg: mtwifi module is build-in, please reboot the device!")
                return
            end
            reinstall_wifidrv = true
        end
    end

    if string.find(profile, "dbdc") then
        if reinstall_wifidrv then
            for k, _ in pairs(l1dat.devname_ridx) do
                mtwifi_down(k)
            end

            mtwifi_reinstall()

            for k, _ in pairs(l1dat.devname_ridx) do
                if k == devname then
                    mtwifi_up(k, cfg, nil, true)
                else
                    os.execute("/sbin/wifi up " .. k) -- TODO: may cause deadlock
                end
            end
        else
            local restore_vifs = {}
            -- any profile change means the sibling devices sharing this
            -- INDEX/mainidx are restarted as well
            local restart_other_dbdc_dev = next(cfg_diff) ~= nil

            if restart_other_dbdc_dev then
                for k, j in pairs(l1dat.devname_ridx) do
                    if j.INDEX == dev.INDEX and j.mainidx == dev.mainidx then
                        local ret = mtwifi_down(k, cfg)
                        for _, v in ipairs(ret) do
                            table.insert(restore_vifs, v)
                        end
                    end
                end
                if #restore_vifs > 0 then
                    nixio.syslog("info", "mtwifi-cfg: dbdc restore_vifs: "..inspect.inspect(restore_vifs))
                end

                for k, j in pairs(l1dat.devname_ridx) do
                    if j.INDEX == dev.INDEX and j.mainidx == dev.mainidx then
                        mtwifi_up(k, cfg, restore_vifs, true)
                    end
                end
            else
                mtwifi_down(devname)
                mtwifi_up(devname, cfg, nil, true)
            end
        end
    else
        mtwifi_down(devname)

        if reinstall_wifidrv then
            mtwifi_reinstall()
        end

        mtwifi_up(devname, cfg)
    end

    mtwifi_cfg_iwpriv_hook(cfg)
end

-- Take down one device, or every device in l1dat.devname_ridx when
-- `devname` is nil/false.
function mtwifi_cfg_down(devname)
    if devname then
        mtwifi_down(devname)
        return
    end

    for name in pairs(l1dat.devname_ridx) do
        mtwifi_down(name)
    end
end

-- Command-line entry points, keyed by the first argument.
local action = {
    -- "down [devname]": take down one device, or all when no name is given
    ["down"] = function(devname)
        mtwifi_cfg_down(devname)
    end,

    -- "setup": read one line of JSON config from stdin and apply it
    ["setup"] = function()
        local argv = io.read()
        -- io.read() returns nil at end of input
        if argv and #argv > 0 then
            mtwifi_cfg_setup(argv)
        end
    end
}

-- Dispatch only for one or two arguments; any other count is a no-op.
if #arg == 1 or #arg == 2 then
    local handler = action[arg[1]]
    if handler then
        handler(arg[2])
    else
        -- previously an unknown action failed with "attempt to call a nil value"
        io.stderr:write("mtwifi-cfg: unknown action: " .. tostring(arg[1]) .. "\n")
        os.exit(1)
    end
end
